{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5d6b30e2-925e-40b5-9116-7627827da690",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Import necessary libraries\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from PyEMD import EMD\n",
    "from scipy.stats import skew\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "359872eb-a624-4b71-98b4-3a616d9c66c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. Define constants\n",
    "sampling_rate = 256  # Hz\n",
    "window_size = 768    # One segment\n",
    "max_imf = 5\n",
    "low_dir = r\"D:\\PPG Dataset\\Low_MWL\\Low_MWL\"\n",
    "high_dir = r\"D:\\PPG Dataset\\High_MWL\\High_MWL\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f644314f-5299-4e8e-b5e5-540c5e46800d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Define function to extract features from IMF1\n",
    "def extract_features_from_segment(segment):\n",
    "    emd = EMD(spline_kind='cubic', MAX_ITERATION=100)\n",
    "    imfs = emd(segment, max_imf=max_imf)\n",
    "    if imfs.shape[0] > 0:\n",
    "        imf1 = imfs[0]\n",
    "        return [np.mean(imf1), np.min(imf1), np.max(imf1), skew(imf1)]\n",
    "    else:\n",
    "        return [0, 0, 0, 0]  # fallback\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "014b4f2d-df79-4c6e-bed7-588759c9ad97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. Loop through all files and extract features\n",
    "def process_directory(directory_path, label):\n",
    "    features = []\n",
    "    for file in os.listdir(directory_path):\n",
    "        if file.endswith(\".csv\"):\n",
    "            path = os.path.join(directory_path, file)\n",
    "            df = pd.read_csv(path)\n",
    "            for col in df.columns:\n",
    "                signal = df[col].dropna().values\n",
    "                num_segments = len(signal) // window_size\n",
    "                for i in range(num_segments):\n",
    "                    segment = signal[i*window_size:(i+1)*window_size]\n",
    "                    if len(segment) == window_size:\n",
    "                        feats = extract_features_from_segment(segment)\n",
    "                        feats.append(label)\n",
    "                        features.append(feats)\n",
    "    return features\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "68b110e2-3c8c-441b-95a5-9b6dd918fca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. Process both drowsy and burst sets\n",
    "low_features = process_directory(low_dir, label=0)   # drowsy\n",
    "high_features = process_directory(high_dir, label=1) # burst\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ae6a15ba-b1af-4bd4-836c-9e1abf9eb446",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 6. Combine and create DataFrame\n",
    "all_data = pd.DataFrame(low_features + high_features,\n",
    "                        columns=[\"Imf1_Mean\", \"Imf1_Min\", \"Imf1_Max\", \"Imf1_Skew\", \"Label\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e33c179-04cf-4314-a584-c69c6ab187f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 7. Save features (optional)\n",
    "all_data.to_csv(\"imf_drowsy_burst_data.csv\", index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8007b1fc-5cdd-4808-8620-bfbb99487ecc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RandomForestClassifier(random_state=42)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 8. Split and train model\n",
    "X = all_data.drop(\"Label\", axis=1)\n",
    "y = all_data[\"Label\"]\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "rf_model = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "rf_model.fit(X_train, y_train)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ed9d4042-dbe3-4df5-8c60-6b15831d359f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "🧠 MLP Classifier\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "      Drowsy       0.59      0.57      0.58       871\n",
      "       Burst       0.59      0.61      0.60       889\n",
      "\n",
      "    accuracy                           0.59      1760\n",
      "   macro avg       0.59      0.59      0.59      1760\n",
      "weighted avg       0.59      0.59      0.59      1760\n",
      "\n",
      "Accuracy: 0.5903409090909091\n"
     ]
    }
   ],
   "source": [
    "mlp = MLPClassifier(hidden_layer_sizes=(64, 32), max_iter=300, random_state=42)\n",
    "mlp.fit(X_train, y_train)\n",
    "mlp_preds = mlp.predict(X_test)\n",
    "print(\"\\n🧠 MLP Classifier\")\n",
    "print(classification_report(y_test, mlp_preds, target_names=[\"Drowsy\", \"Burst\"]))\n",
    "print(\"Accuracy:\", accuracy_score(y_test, mlp_preds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7d344660-818c-457d-9286-2197caf62a42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ MLP model saved successfully as 'drowsy_burst_mlp_model.joblib'\n"
     ]
    }
   ],
   "source": [
    "# Save the trained MLP model\n",
    "joblib.dump(mlp, 'drowsy_burst_mlp_model.joblib')\n",
    "print(\"✅ MLP model saved successfully as 'drowsy_burst_mlp_model.joblib'\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41d60e7f-be1d-4609-ab38-772b2382a96e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
